{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abb4368",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"make variations of input image\"\"\"\n",
    "\n",
    "import argparse, os, sys, glob\n",
    "import PIL\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from tqdm import tqdm, trange\n",
    "from itertools import islice\n",
    "from einops import rearrange, repeat\n",
    "from torchvision.utils import make_grid\n",
    "from torch import autocast\n",
    "from contextlib import nullcontext\n",
    "import time\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "sys.path.append(os.path.dirname(sys.path[0]))\n",
    "from ldm.util import instantiate_from_config\n",
    "from ldm.models.diffusion.ddim import DDIMSampler\n",
    "from ldm.models.diffusion.plms import PLMSSampler\n",
    "\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "def chunk(it, size):\n",
    "    it = iter(it)\n",
    "    return iter(lambda: tuple(islice(it, size)), ())\n",
    "\n",
    "\n",
    "def load_model_from_config(config, ckpt, verbose=False):\n",
    "    \"\"\"Build the model described by `config.model` and load weights from `ckpt`.\n",
    "\n",
    "    The checkpoint is read with `map_location=\"cpu\"`, its \"state_dict\" entry is\n",
    "    loaded with `strict=False`, and the model is moved to the module-level\n",
    "    `device` and switched to eval mode before being returned.\n",
    "\n",
    "    Args:\n",
    "        config: Object exposing a `model` attribute accepted by\n",
    "            `instantiate_from_config`.\n",
    "        ckpt: Checkpoint path handed to `torch.load`.\n",
    "        verbose: When true, print the missing / unexpected keys reported by\n",
    "            `load_state_dict` (only if there are any).\n",
    "\n",
    "    Returns:\n",
    "        The instantiated model.\n",
    "    \"\"\"\n",
    "    print(f\"Loading model from {ckpt}\")\n",
    "    # NOTE(review): torch.load deserializes via pickle -- only open trusted checkpoints.\n",
    "    checkpoint = torch.load(ckpt, map_location=\"cpu\")\n",
    "    if \"global_step\" in checkpoint:\n",
    "        print(f\"Global Step: {checkpoint['global_step']}\")\n",
    "    state_dict = checkpoint[\"state_dict\"]\n",
    "\n",
    "    model = instantiate_from_config(config.model)\n",
    "    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n",
    "    if verbose:\n",
    "        # Report each non-empty key list under its own heading.\n",
    "        for heading, keys in ((\"missing keys:\", missing), (\"unexpected keys:\", unexpected)):\n",
    "            if len(keys) > 0:\n",
    "                print(heading)\n",
    "                print(keys)\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_img(path, size=512):\n",
    "    \"\"\"Load an image file as a normalized tensor.\n",
    "\n",
    "    Args:\n",
    "        path: Image location; anything `PIL.Image.open` accepts.\n",
    "        size: Edge length, in pixels, of the square the image is resized to.\n",
    "            Defaults to 512, the fixed resolution this function always used.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape (1, 3, size, size) with values in [-1, 1].\n",
    "    \"\"\"\n",
    "    # The context manager closes the file handle; convert() returns a loaded copy.\n",
    "    with Image.open(path) as raw:\n",
    "        image = raw.convert(\"RGB\")\n",
    "    w, h = image.size\n",
    "    print(f\"loaded input image of size ({w}, {h}) from {path}\")\n",
    "    # The image is always forced to a square, so the aspect ratio is not preserved.\n",
    "    image = image.resize((size, size), resample=PIL.Image.LANCZOS)\n",
    "    image = np.array(image).astype(np.float32) / 255.0  # HWC in [0, 1]\n",
    "    image = image[None].transpose(0, 3, 1, 2)  # add batch axis, HWC -> CHW\n",
    "    image = torch.from_numpy(image)\n",
    "    return 2.*image - 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0abda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to the inference config and the model checkpoint.\n",
    "config_path = \"configs/stable-diffusion/v1-inference.yaml\"\n",
    "ckpt = \"models/sd/sd-v1-4.ckpt\"\n",
    "\n",
    "# Parse the config, build the model from it, then wrap the model in a DDIM sampler.\n",
    "config = OmegaConf.load(config_path)\n",
    "model = load_model_from_config(config, ckpt)\n",
    "sampler = DDIMSampler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ff8a8f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(prompt='', content_dir='', style_dir='', ddim_steps=50, strength=0.5, model=None, seed=42):\n",
    "    \"\"\"Re-render a content image under a prompt and a style image, save and return it.\n",
    "\n",
    "    Both images are loaded with `load_img` and encoded by the model's first\n",
    "    stage. The content latent is noised at timestep `int(strength * 1000)`,\n",
    "    passed through `sampler.stochastic_encode` with the model's own output as\n",
    "    the noise, and decoded for `int(strength * ddim_steps)` steps with guidance\n",
    "    scale 10.0 against an empty-prompt conditioning. The decoded batch is saved\n",
    "    as a PNG grid in `outputs/img2img-samples`.\n",
    "\n",
    "    Uses the module-level `sampler`, `device` and `load_img`.\n",
    "    NOTE(review): `sampler` is presumably built around the same `model` that is\n",
    "    passed in here -- confirm at the call site.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text prompt passed to `model.get_learned_conditioning`; it is\n",
    "            also embedded in the output file name.\n",
    "        content_dir: Path of the content image file (a file, despite the name).\n",
    "        style_dir: Path of the style image file (a file, despite the name).\n",
    "        ddim_steps: Number of DDIM steps used for the sampler schedule.\n",
    "        strength: Value in [0, 1] scaling the noising timestep and the number\n",
    "            of decode steps.\n",
    "        model: The loaded model; required.\n",
    "        seed: Seed passed to `seed_everything`.\n",
    "\n",
    "    Returns:\n",
    "        The saved grid as a PIL image.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `model` is None.\n",
    "        AssertionError: If `strength` is outside [0, 1].\n",
    "    \"\"\"\n",
    "    # Fail fast on bad arguments, before any image is loaded or encoded.\n",
    "    if model is None:\n",
    "        raise ValueError(\"main() requires a loaded model, got model=None\")\n",
    "    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
    "\n",
    "    # Fixed sampling settings.\n",
    "    ddim_eta = 0.0\n",
    "    n_iter = 1\n",
    "    n_samples = 1\n",
    "    n_rows = 0\n",
    "    scale = 10.0\n",
    "    precision = \"autocast\"\n",
    "    outdir = \"outputs/img2img-samples\"\n",
    "    seed_everything(seed)\n",
    "\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    outpath = outdir\n",
    "\n",
    "    batch_size = n_samples\n",
    "    n_rows = n_rows if n_rows > 0 else batch_size\n",
    "    data = [batch_size * [prompt]]\n",
    "\n",
    "    # Kept for its side effect: the `samples` sub-directory is still created,\n",
    "    # although this function writes nothing into it.\n",
    "    sample_path = os.path.join(outpath, \"samples\")\n",
    "    os.makedirs(sample_path, exist_ok=True)\n",
    "    # File-name counter: number of entries in the output directory plus 10.\n",
    "    grid_count = len(os.listdir(outpath)) + 10\n",
    "\n",
    "    style_image = load_img(style_dir).to(device)\n",
    "    style_image = repeat(style_image, '1 ... -> b ...', b=batch_size)\n",
    "    # The style latent itself is unused below.\n",
    "    # NOTE(review): the encode may draw from the RNG, so dropping this call could\n",
    "    # change seeded results -- confirm before removing it.\n",
    "    style_latent = model.get_first_stage_encoding(model.encode_first_stage(style_image))\n",
    "\n",
    "    content_name = content_dir.split('/')[-1].split('.')[0]\n",
    "    content_image = load_img(content_dir).to(device)\n",
    "    content_image = repeat(content_image, '1 ... -> b ...', b=batch_size)\n",
    "    init_latent = model.get_first_stage_encoding(model.encode_first_stage(content_image))\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)\n",
    "\n",
    "    # Number of DDIM steps used when decoding.\n",
    "    t_enc = int(strength * ddim_steps)\n",
    "    print(f\"target t_enc is {t_enc} steps\")\n",
    "\n",
    "    # Timestep used for noising and stochastic_encode(use_original_steps=True);\n",
    "    # presumably an index into the model's original 1000-step schedule.\n",
    "    # Built once: it does not change between iterations.\n",
    "    t_orig = torch.tensor([int(strength * 1000)] * batch_size).to(device)\n",
    "\n",
    "    precision_scope = autocast if precision == \"autocast\" else nullcontext\n",
    "    with torch.no_grad(), precision_scope(\"cuda\"), model.ema_scope():\n",
    "        all_samples = list()\n",
    "        for n in trange(n_iter, desc=\"Sampling\"):\n",
    "            for prompts in tqdm(data, desc=\"data\"):\n",
    "                uc = None\n",
    "                if scale != 1.0:\n",
    "                    uc = model.get_learned_conditioning(batch_size * [\"\"], style_image)\n",
    "                if isinstance(prompts, tuple):\n",
    "                    prompts = list(prompts)\n",
    "\n",
    "                c = model.get_learned_conditioning(prompts, style_image)\n",
    "\n",
    "                # Stochastic inversion: noise the content latent, let the model\n",
    "                # predict on it, and use that prediction as the encoding noise.\n",
    "                x_noisy = model.q_sample(x_start=init_latent, t=t_orig)\n",
    "                model_output = model.apply_model(x_noisy, t_orig, c)\n",
    "                z_enc = sampler.stochastic_encode(init_latent, t_orig,\n",
    "                                                  noise=model_output, use_original_steps=True)\n",
    "\n",
    "                samples = sampler.decode(z_enc, c, t_enc,\n",
    "                                         unconditional_guidance_scale=scale,\n",
    "                                         unconditional_conditioning=uc)\n",
    "                print(z_enc.shape, None if uc is None else uc.shape, t_enc)\n",
    "\n",
    "                x_samples = model.decode_first_stage(samples)\n",
    "                # Map from [-1, 1] to [0, 1].\n",
    "                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "                all_samples.append(x_samples)\n",
    "\n",
    "        # Arrange all decoded samples in one grid image.\n",
    "        grid = torch.stack(all_samples, 0)\n",
    "        grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
    "        grid = make_grid(grid, nrow=n_rows)\n",
    "\n",
    "        # Convert to an 8-bit image and save it.\n",
    "        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
    "        output = Image.fromarray(grid.astype(np.uint8))\n",
    "        output.save(os.path.join(outpath, content_name+'-'+prompt+f'-{grid_count:04}.png'))\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08dd2dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `{log_dir}` below is a literal placeholder (the string is not an f-string);\n",
    "# presumably it must be replaced with the actual run directory under ./logs.\n",
    "embedding_path = './logs/{log_dir}/checkpoints/embeddings.pt'\n",
    "model.embedding_manager.load(embedding_path)\n",
    "\n",
    "# Move the model to `device`.\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca318f6c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# `{image_name}` is a literal placeholder; presumably it must be replaced with\n",
    "# real content / style image paths. The call is the cell's last expression,\n",
    "# so the returned image is shown as the cell output.\n",
    "main(\n",
    "    prompt='*',\n",
    "    content_dir='{image_name}',\n",
    "    style_dir='{image_name}',\n",
    "    ddim_steps=50,\n",
    "    strength=0.7,\n",
    "    seed=42,\n",
    "    model=model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa770b54",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "4bfdbc5ecf268fe8cbe1003c5e2c130e872c62898120a6963bc33993ee6594f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
